import numpy as np
import config as cfg
import os
import glob
from PIL import Image
from eval_metrics import get_similarity_metric

# Evaluation script: for every dataset / subject, load the saved evaluation
# images, compare each predicted image set against image 0 of every test case
# with `get_similarity_metric`, print the mean scores and save the raw results
# next to the images.
dataset = ['WM', 'SK']
result_dir = os.path.join(cfg.project_path, 'exps', 'results', 'diffusion_finetune')
model_date = '19-05-2024-00-47-48'

# Every test case is stored as `test{i}-{j}.png` with j in [0, IMAGES_PER_TEST).
# Index 0 is used as the ground-truth image, the remaining indices as predictions.
IMAGES_PER_TEST = 6


def _load_eval_images(eval_image_dir):
    """Load all `test{i}-{j}.png` images found in *eval_image_dir*.

    Returns an array indexed as [test, image-within-test, ...pixels], or
    ``None`` when the directory holds no complete test case.
    """
    png_files = glob.glob(os.path.join(eval_image_dir, '*.png'))
    # Filter on the file name only: checking the full path would discard every
    # image whenever a parent directory happens to contain 'sample'.
    png_files = [png for png in png_files if 'sample' not in os.path.basename(png)]

    test_num, leftover = divmod(len(png_files), IMAGES_PER_TEST)
    if leftover:
        print(f'Warning: {len(png_files)} images in {eval_image_dir} is not a multiple of '
              f'{IMAGES_PER_TEST}; ignoring the incomplete remainder')
    if test_num == 0:
        return None

    tests = []
    for i in range(test_num):
        images = []
        for j in range(IMAGES_PER_TEST):
            image_path = os.path.join(eval_image_dir, f'test{i}-{j}.png')
            # The context manager closes the underlying file once the pixels
            # have been converted and copied into the array.
            with Image.open(image_path) as image:
                images.append(np.array(image.convert('RGB')))
        tests.append(np.stack(images))
    return np.stack(tests)


def _n_way_scores(pred_images, gt_images, n_way):
    """Run the n-way 'class' similarity metric and return its result as an array."""
    res = get_similarity_metric(pred_images, gt_images, 'class', None,
                                n_way=n_way, num_trials=50, top_k=[1, 5], device='cuda')
    return np.array(res)


for data in dataset:
    for sub in range(0, 10):
        eval_image_dir = os.path.join(result_dir, data, f'sub_{sub}', model_date, 'eval')
        samples = _load_eval_images(eval_image_dir)
        if samples is None:
            # A missing or empty run should not abort the remaining subjects.
            print(f'For {data} sub_{sub}, no evaluation images found in {eval_image_dir}; skipping')
            continue

        gt_images = samples[:, 0]

        res_part_50 = []
        res_part_100 = []
        for s in range(1, samples.shape[1]):
            pred_images = samples[:, s]
            # Keep the 50-way call before the 100-way call for each prediction
            # index, matching the original evaluation order.
            res_part_50.append(_n_way_scores(pred_images, gt_images, 50))
            res_part_100.append(_n_way_scores(pred_images, gt_images, 100))
        res_part_50 = np.array(res_part_50)
        res_part_100 = np.array(res_part_100)

        # The last axis is indexed 0 / 1 for the two requested top_k values
        # (1 and 5) -- assumes the metric returns them in that order; TODO confirm.
        top1_50 = np.mean(res_part_50[:, :, 0])
        top5_50 = np.mean(res_part_50[:, :, 1])
        top1_100 = np.mean(res_part_100[:, :, 0])
        top5_100 = np.mean(res_part_100[:, :, 1])
        print(f'For {data} sub_{sub}, top-1 50-way acc: {top1_50}, top-5 50-way acc: {top5_50}')
        print(f'For {data} sub_{sub}, top-1 100-way acc: {top1_100}, top-5 100-way acc: {top5_100}')
        np.savez(os.path.join(eval_image_dir, 'n-way_acc.npz'), res_part_50=res_part_50, res_part_100=res_part_100)


